/*
 * JBoss, Home of Professional Open Source
 * Copyright 2010, Red Hat, Inc. and individual contributors
 * by the @authors tag. See the copyright.txt in the distribution for a
 * full listing of individual contributors.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */

package org.jdiameter.client.impl.annotation;

import org.jdiameter.api.*;
import org.jdiameter.api.annotation.*;
import org.jdiameter.client.impl.annotation.internal.ClassInfo;
import org.jdiameter.client.impl.annotation.internal.ConstructorInfo;
import org.jdiameter.client.impl.annotation.internal.MethodInfo;
import org.jdiameter.client.impl.annotation.internal.Storage;
import org.jdiameter.client.impl.RawSessionImpl;
import org.jdiameter.client.api.annotation.IRecoder;
import org.jdiameter.client.api.annotation.RecoderException;
import org.jdiameter.client.api.IMessage;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.*;

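/**
 * Annotation-driven converter between domain objects and Diameter messages.
 * The {@code encode*} methods build a {@link Message} from an object whose class is
 * annotated with {@link CommandDscr}; {@code decode} instantiates such a class from a message.
 *
 * @author erick.svenson@yahoo.com
 * @author <a href="mailto:brainslog@gmail.com"> Alexandre Mendonca </a>
 * @author <a href="mailto:baranowb@gmail.com"> Bartosz Baranowski </a>
 */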
public class Recoder implements IRecoder {

  // TODO full min/max/position constraints and optimization (caching)

  /** Class-wide SLF4J logger. */
  private static final Logger log = LoggerFactory.getLogger(Recoder.class);

  /** Supplies the {@link ClassInfo} metadata (annotations, constructors, methods) used by every encode/decode step. */
  private Storage storage = new Storage();
  /** Raw session obtained from the session factory in the constructor; used to create new request messages. */
  private final RawSessionImpl rawSession;
  /** Stack metadata; read for the local peer's FQDN and realm name when a request lacks Origin-Host / Origin-Realm. */
  private final MetaData metaData;

  public Recoder(SessionFactory factory, MetaData metaData) {
    this.metaData = metaData;
    try {
      this.rawSession = (RawSessionImpl) factory.getNewRawSession();
    } catch (InternalException e) {
      throw new IllegalArgumentException(e);
    }
  }
  // =======================================================================================


  /**
   * Encodes a domain object as a new Diameter request.
   *
   * @param yourDomainMessageObject object whose class carries the command annotation
   * @param additionalAvp           extra AVPs added to the request before the object's own AVPs
   * @return the encoded message, as produced by {@code encode} with no originating request
   * @throws RecoderException if encoding fails
   */
  //@Override
  public Message encodeToRequest(Object yourDomainMessageObject, Avp... additionalAvp) throws RecoderException {
    // A null originating request is what tells encode(...) to build a request.
    final Request noOriginatingRequest = null;
    final long unusedResultCode = 0;
    return encode(yourDomainMessageObject, noOriginatingRequest, unusedResultCode, additionalAvp);
  }

  /**
   * Encodes a domain object as an answer to the given request.
   *
   * @param yourDomainMessageObject object whose class carries the command annotation
   * @param request                 request being answered
   * @param resultCode              result code passed to {@code request.createAnswer}
   * @return the encoded message, as produced by {@code encode} with no additional AVPs
   * @throws RecoderException if encoding fails
   */
  //@Override
  public Message encodeToAnswer(Object yourDomainMessageObject, Request request, long resultCode) throws RecoderException {
    final Avp[] noAdditionalAvps = new Avp[0];
    return encode(yourDomainMessageObject, request, resultCode, noAdditionalAvps);
  }

  /**
   * Encodes an annotated domain object into a Diameter message.
   * <p>
   * When {@code request} is {@code null} a new request is created through the raw session,
   * {@code addAvp} is appended, the header Application-Id is derived from the application-id
   * AVPs present, and Origin-Host / Origin-Realm are filled from the local peer when missing.
   * Otherwise an answer to {@code request} is created with {@code resultCode} (and
   * {@code addAvp} is not used). In both cases the flags listed in the class's
   * {@link CommandDscr} are set and every declared child is encoded via {@code fillChild}.
   *
   * @param yourDomainMessageObject object to encode; its class is expected to carry {@link CommandDscr}
   * @param request                 request to answer, or {@code null} to build a request
   * @param resultCode              result code for the answer; ignored when building a request
   * @param addAvp                  extra AVPs for a new request; may be empty or {@code null}
   * @return the encoded message, or {@code null} when the class has no {@link CommandDscr}
   * @throws RecoderException if the application-id AVPs cannot be read or a getter fails
   */
  public Message encode(Object yourDomainMessageObject, Request request, long resultCode, Avp... addAvp) throws RecoderException {
    ClassInfo classInfo = storage.getClassInfo(yourDomainMessageObject.getClass());
    CommandDscr commandDscr = classInfo.getAnnotation(CommandDscr.class);
    if (commandDscr == null) {
      log.debug("Can not found annotation for object {}", yourDomainMessageObject);
      return null;
    }

    IMessage message;
    if (request == null) {
      message = (IMessage) rawSession.createMessage(commandDscr.code(), ApplicationId.createByAccAppId(0));
      message.setRequest(true);
      AvpSet avps = message.getAvps();
      if (addAvp != null) {
        avps.addAvp(addAvp);
      }
      try {
        Avp applicationId = findHeaderApplicationIdAvp(avps);
        if (applicationId != null) {
          message.setHeaderApplicationId(applicationId.getUnsigned32());
        }
      } catch (Exception exc) {
        // Kept broad on purpose: any failure while reading these AVPs is reported as a RecoderException.
        throw new RecoderException(exc);
      }
      if (avps.getAvp(Avp.ORIGIN_HOST) == null) {
        avps.addAvp(Avp.ORIGIN_HOST, metaData.getLocalPeer().getUri().getFQDN(), true, false, true);
      }
      if (avps.getAvp(Avp.ORIGIN_REALM) == null) {
        avps.addAvp(Avp.ORIGIN_REALM, metaData.getLocalPeer().getRealmName(), true, false, true);
      }
    } else {
      message = (IMessage) request.createAnswer(resultCode);
    }

    for (CommandFlag f : commandDscr.flags()) {
      switch (f) {
        case E:
          message.setError(true);
          break;
        case P:
          message.setProxiable(true);
          break;
        case R:
          message.setRequest(true);
          break;
        case T:
          message.setReTransmitted(true);
          break;
      }
    }

    // Collect the values returned by the object's @Getter methods, then encode each declared child.
    Map<String, Object> chMap = getChildInstance(yourDomainMessageObject, classInfo, null);
    for (Child ch : commandDscr.childs()) {
      fillChild(message.getAvps(), ch, chMap);
    }
    return message;
  }

  /**
   * Picks the AVP whose Unsigned32 value becomes the header Application-Id of a new request.
   * Order: top-level Auth-Application-Id, top-level Acct-Application-Id, then the
   * Auth-/Acct-Application-Id nested in Vendor-Specific-Application-Id. The nested Vendor-Id is
   * used only as a last resort, to keep the previous behaviour for groups that carry no
   * application id.
   *
   * @return the selected AVP, or {@code null} when the set holds none of the three top-level AVPs
   * @throws AvpDataException         if the grouped AVP cannot be parsed
   * @throws IllegalArgumentException if Vendor-Specific-Application-Id is present but empty of ids
   */
  private Avp findHeaderApplicationIdAvp(AvpSet avps) throws AvpDataException {
    Avp found = avps.getAvp(Avp.AUTH_APPLICATION_ID);
    if (found != null) {
      return found;
    }
    found = avps.getAvp(Avp.ACCT_APPLICATION_ID);
    if (found != null) {
      return found;
    }
    Avp vendorSpecific = avps.getAvp(Avp.VENDOR_SPECIFIC_APPLICATION_ID);
    if (vendorSpecific == null) {
      return null;
    }
    AvpSet members = vendorSpecific.getGrouped();
    found = members.getAvp(Avp.AUTH_APPLICATION_ID);
    if (found == null) {
      found = members.getAvp(Avp.ACCT_APPLICATION_ID);
    }
    if (found == null) {
      found = members.getAvp(Avp.VENDOR_ID);
    }
    if (found == null) {
      throw new IllegalArgumentException("Vendor-Specific-Application-Id carries no application id");
    }
    return found;
  }

  /**
   * Invokes every {@link Getter}-annotated method of the given object and indexes the non-null
   * results by type name.
   * <p>
   * Each value is registered under the name of its class (the component type for arrays) and
   * under the name of every interface that class directly implements.
   *
   * @param yourDomainMessageObject object whose getters are invoked
   * @param c                       class metadata listing the methods to inspect
   * @param chMap                   map to fill, or {@code null} to start from a new one
   * @return the filled map
   * @throws RecoderException if a getter cannot be invoked or throws
   */
  private Map<String, Object> getChildInstance(Object yourDomainMessageObject, ClassInfo c, Map<String, Object> chMap) 
      throws RecoderException {
    final Map<String, Object> valuesByTypeName = (chMap != null) ? chMap : new HashMap<String, Object>();
    for (MethodInfo candidate : c.getMethodsInfo()) {
      if (candidate.getAnnotation(Getter.class) == null) {
        continue;
      }
      final Object produced;
      try {
        produced = candidate.getMethod().invoke(yourDomainMessageObject);
      } catch (IllegalAccessException e) {
        throw new RecoderException(e);
      } catch (InvocationTargetException e) {
        throw new RecoderException(e);
      }
      if (produced == null) {
        continue;
      }
      Class<?> valueType = produced.getClass();
      if (valueType.isArray()) {
        valueType = valueType.getComponentType();
      }
      valuesByTypeName.put(valueType.getName(), produced);
      for (Class<?> implemented : valueType.getInterfaces()) {
        valuesByTypeName.put(implemented.getName(), produced);
      }
    }
    return valuesByTypeName;
  }

  /**
   * Encodes one declared child into the given AVP set.
   * <p>
   * The value is looked up in {@code childs} by the name of the child's referenced class. When a
   * value exists and that class carries {@link AvpDscr}, one AVP is added per value read through
   * {@code getValue}; for grouped AVPs one grouped AVP is added per instance and its own children
   * are encoded recursively. The M and P flags are taken from {@code AvpDscr.must()}.
   * Null values (a getter returning {@code null}, or a null array element) are skipped, in line
   * with how null getter results are skipped when the child map is built.
   *
   * @param as     AVP set receiving the encoded AVPs
   * @param ci     child descriptor naming the class to encode
   * @param childs values indexed by type name
   * @throws RecoderException if a getter cannot be invoked or throws
   */
  private void fillChild(AvpSet as, Child ci, Map<String, Object> childs) throws RecoderException {
    Object c = childs.get(ci.ref().getName());
    if (c == null) {
      return;
    }
    AvpDscr ad = storage.getClassInfo(ci.ref()).getAnnotation(AvpDscr.class);
    if (ad == null) {
      return;
    }

    // Flags are the same for every AVP type, so work them out once.
    boolean m = false, p = false;
    for (AvpFlag f : ad.must()) {
      if (AvpFlag.M.equals(f)) {
        m = true;
      }
      else if (AvpFlag.P.equals(f)) {
        p = true;
      }
    }

    // The boxed element type of each loop selects the matching addAvp overload.
    switch (ad.type()) {
      case Integer32:
      case Enumerated:
        for (Integer v : getValue(c, Integer.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p);
          }
        }
        break;
      case Unsigned32:
        for (Long v : getValue(c, Long.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p, true);
          }
        }
        break;
      case Unsigned64:
      case Integer64:
        for (Long v : getValue(c, Long.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p);
          }
        }
        break;
      case Float32:
        for (Float v : getValue(c, Float.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p);
          }
        }
        break;
      case Float64:
        for (Double v : getValue(c, Double.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p);
          }
        }
        break;
      case OctetString:
      case Address:
      case Time:
      case DiameterIdentity:
      case DiameterURI:
      case IPFilterRule:
      case QoSFilterRule:
        for (String v : getValue(c, String.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p, true);
          }
        }
        break;
      case UTF8String:
        for (String v : getValue(c, String.class)) {
          if (v != null) {
            as.addAvp(ad.code(), v, ad.vendorId(), m, p, false);
          }
        }
        break;
      case Grouped: {
        Object[] instances = c.getClass().isArray() ? (Object[]) c : new Object[] {c};
        for (Object cj : instances) {
          if (cj == null) {
            continue;
          }
          AvpSet las = as.addGroupedAvp(ad.code(), ad.vendorId(), m, p);
          Map<String, Object> lchilds = getChildInstance(cj, storage.getClassInfo(cj.getClass()), null);
          for (Child lci : ad.childs()) {
            fillChild(las, lci, lchilds);
          }
        }
        break;
      }
    }
  }

  /**
   * Reads the raw values held by one AVP object or by each element of an array of them.
   * <p>
   * For every holder, each {@link Getter}-annotated method is invoked and its result appended
   * to the returned collection (null results included). The results are cast unchecked to
   * {@code T}; {@code type} only fixes the generic type at the call site.
   *
   * @param ic   a single holder object or an array of holders
   * @param type expected element type of the result
   * @return the getter results, in holder order then method order
   * @throws RecoderException if a getter cannot be invoked or throws
   */
  private <T> Collection<T> getValue(Object ic, Class<T> type) throws RecoderException {
    final Collection<T> collected = new ArrayList<T>();
    final Object[] holders = ic.getClass().isArray() ? (Object[]) ic : new Object[] {ic};
    for (int idx = 0; idx < holders.length; idx++) {
      final Object holder = holders[idx];
      for (MethodInfo accessor : storage.getClassInfo(holder.getClass()).getMethodsInfo()) {
        if (accessor.getAnnotation(Getter.class) == null) {
          continue;
        }
        try {
          @SuppressWarnings("unchecked")
          T value = (T) accessor.getMethod().invoke(holder);
          collected.add(value);
        } catch (IllegalAccessException e) {
          throw new RecoderException(e);
        } catch (InvocationTargetException e) {
          throw new RecoderException(e);
        }
      }
    }
    return collected;
  } 

  // =======================================================================================

  /**
   * Decodes a Diameter message into a new instance of an annotated domain class.
   * <p>
   * The message's command code, Application-Id (when non-zero) and the flags listed in the
   * class's {@link CommandDscr} are validated first. The instance is then created through the
   * {@link Setter}-annotated constructor with the most parameters whose types all carry
   * {@link AvpDscr} (or through the no-arg constructor when none qualifies). Remaining AVPs are
   * pushed through single-argument {@link Setter} methods, including array-typed ones, and AVPs
   * not consumed that way are handed to the "undefined" setter, if any.
   *
   * @param message                 message to decode
   * @param yourDomainMessageObject target class
   * @return the populated instance, or {@code null} when the class has no {@link CommandDscr}
   * @throws IllegalArgumentException if the command code, Application-Id or a required flag
   *                                  does not match
   * @throws RecoderException         if the instance cannot be created or a setter fails
   */
  public <T> T decode(Message message, java.lang.Class<T> yourDomainMessageObject) throws RecoderException {
    ClassInfo c = storage.getClassInfo(yourDomainMessageObject);
    CommandDscr cd = c.getAnnotation(CommandDscr.class);
    if (cd == null) {
      return null;
    }
    try {
      if (message.getCommandCode() != cd.code()) {
        throw new IllegalArgumentException("Invalid message code " + message.getCommandCode());
      }
      if (message.getApplicationId() != 0 && message.getApplicationId() != cd.appId()) {
        throw new IllegalArgumentException("Invalid Application-Id " + message.getApplicationId());
      }
      for (CommandFlag f : cd.flags()) {
        switch (f) {
          case E:
            if (!message.isError()) {
              throw new IllegalArgumentException("Flag e is not set");
            }
            break;
          case P:
            if (!message.isProxiable()) {
              throw new IllegalArgumentException("Flag p is not set");
            }
            break;
          case R:
            if (!message.isRequest()) {
              throw new IllegalArgumentException("Flag r is not set");
            }
            break;
          case T:
            if (!message.isReTransmitted()) {
              throw new IllegalArgumentException("Flag t is not set");
            }
            break;
        }
      }

      // Pick the @Setter constructor with the most parameters, all of which must be AVP classes.
      int cacount = 0;
      Constructor<?> cm = null;
      for (ConstructorInfo ci : c.getConstructorsInfo()) {
        if (ci.getAnnotation(Setter.class) == null) {
          continue;
        }
        Class<?>[] params = ci.getConstructor().getParameterTypes();
        boolean correct = true;
        for (Class<?> j : params) {
          Class<?> jc = j.isArray() ? j.getComponentType() : j;
          if (storage.getClassInfo(jc).getAnnotation(AvpDscr.class) == null) {
            correct = false;
            break;
          }
        }
        if (correct && cacount < params.length) {
          cacount = params.length;
          cm = ci.getConstructor();
        }
      }

      // Build the instance; cmargs remembers which AVP classes have already been consumed.
      Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>();
      Object rc;
      if (cm != null) {
        List<Object> initargs = new ArrayList<Object>();
        for (Class<?> ac : cm.getParameterTypes()) {
          Class<?> lac = ac.isArray() ? ac.getComponentType() : ac;
          cmargs.put(lac.getName(), ac);
          initargs.add(createChildByAvp(findChildDscr(cd.childs(), ac), ac, message.getAvps()));
        }
        rc = cm.newInstance(initargs.toArray());
      } else {
        rc = yourDomainMessageObject.newInstance();
      }

      // Feed AVPs not covered by the constructor through single-argument setters.
      for (MethodInfo mi : c.getMethodsInfo()) {
        if (mi.getAnnotation(Setter.class) == null) {
          continue;
        }
        Class<?>[] pt = mi.getMethod().getParameterTypes();
        if (pt.length != 1) {
          continue;
        }
        // Unwrap arrays before looking for AvpDscr, as the constructor path does;
        // otherwise array-typed setters are never invoked.
        Class<?> ptc = pt[0].isArray() ? pt[0].getComponentType() : pt[0];
        if (storage.getClassInfo(ptc).getAnnotation(AvpDscr.class) == null) {
          continue;
        }
        if (!cmargs.containsKey(ptc.getName())) {
          cmargs.put(ptc.getName(), ptc);
          // Static type Object keeps an array value as one reflective argument.
          Object arg = createChildByAvp(findChildDscr(cd.childs(), pt[0]), pt[0], message.getAvps());
          mi.getMethod().invoke(rc, arg);
        }
      }

      // Fill undefined avp
      setUndefinedAvp(message.getAvps(), rc, c, cmargs);
      return yourDomainMessageObject.cast(rc);
    } catch (InstantiationException e) {
      throw new RecoderException(e);
    } catch (InvocationTargetException e) {
      throw new RecoderException(e);
    } catch (IllegalAccessException e) {
      throw new RecoderException(e);
    }
  }

  /**
   * Passes every AVP that was not consumed by a constructor or setter to the first
   * {@link Setter} method declared with {@code Setter.Type.UNDEFINED}.
   * <p>
   * An AVP counts as consumed when its code equals the {@link AvpDscr} code of one of the
   * classes in {@code cmargs}. Each remaining AVP is wrapped in an {@code UnknownAvp} and the
   * setter is invoked once per AVP. Nothing happens when the class has no such setter.
   *
   * @param set    AVPs to scan
   * @param rc     instance receiving the unknown AVPs
   * @param c      class metadata of {@code rc}
   * @param cmargs AVP classes already consumed, keyed by class name
   * @throws RecoderException if the setter cannot be invoked, throws, or an AVP's raw data
   *                          cannot be read
   */
  private void setUndefinedAvp(AvpSet set, Object rc, ClassInfo c, Map<String, Class<?>> cmargs) throws RecoderException {
    try {
      // Locate the first setter flagged as the sink for undefined AVPs.
      MethodInfo sink = null;
      for (MethodInfo candidate : c.getMethodsInfo()) {
        Setter marker = candidate.getAnnotation(Setter.class);
        if (marker != null && Setter.Type.UNDEFINED.equals(marker.value())) {
          sink = candidate;
          break;
        }
      }
      if (sink == null) {
        return;
      }

      // Codes of the AVPs that have already been mapped onto the instance.
      Set<Integer> consumedCodes = new HashSet<Integer>();
      for (Class<?> consumed : cmargs.values()) {
        Class<?> avpClass = consumed.isArray() ? consumed.getComponentType() : consumed;
        AvpDscr descriptor = storage.getClassInfo(avpClass).getAnnotation(AvpDscr.class);
        consumedCodes.add(descriptor.code());
      }

      for (Avp a : set) {
        if (consumedCodes.contains(a.getCode())) {
          continue;
        }
        UnknownAvp unknown =
            new UnknownAvp(a.getCode(), a.isMandatory(), a.isVendorId(), a.isEncrypted(), a.getVendorId(), a.getRaw());
        sink.getMethod().invoke(rc, unknown);
      }
    } catch (IllegalAccessException e) {
      throw new RecoderException(e);
    } catch (InvocationTargetException e) {
      throw new RecoderException(e);
    } catch (AvpDataException e) {
      throw new RecoderException(e);
    }
  }

  /**
   * Finds the child descriptor that refers to the given type.
   * <p>
   * Arrays are matched by their component type. A descriptor matches when its referenced class
   * is the type itself, the type's direct superclass, or one of the interfaces the type
   * directly implements. Descriptors are tested in array order.
   *
   * @param childs descriptors to search
   * @param m      parameter type (possibly an array type) to match
   * @return the first matching descriptor, or {@code null} when none matches
   */
  private Child findChildDscr(Child[] childs, Class<?> m) {
    Class<?> wanted = m;
    for (int idx = 0; idx < childs.length; idx++) {
      final Child candidate = childs[idx];
      final Class<?> referenced = candidate.ref();
      if (wanted.isArray()) {
        wanted = wanted.getComponentType();
      }
      final boolean sameOrParent = (wanted == referenced) || (wanted.getSuperclass() == referenced);
      if (sameOrParent || Arrays.asList(wanted.getInterfaces()).contains(referenced)) {
        return candidate;
      }
    }
    return null;
  }

  /**
   * Creates the domain value (or array of values) for one AVP class from a parent AVP set.
   * <p>
   * The AVP code and type come from the {@link AvpDscr} of {@code m} (its component type for
   * arrays). When the AVP is present, the flags listed in {@code AvpDscr.must()} are verified.
   * When it is absent and the child descriptor requires at least one occurrence, the call
   * fails. For array types one element is built per occurrence of the AVP; otherwise a single
   * object is built from the first occurrence.
   *
   * @param mInfo     child descriptor for {@code m}; may be {@code null} when the enclosing
   *                  descriptor does not list this type, in which case no minimum is enforced
   * @param m         target type, possibly an array type
   * @param parentSet AVP set to read from
   * @return the created object or array; {@code null} for a non-array type whose AVP is absent
   * @throws IllegalArgumentException if a required flag is not set or a mandatory AVP is missing
   * @throws RecoderException         if a grouped AVP cannot be instantiated
   */
  private Object createChildByAvp(Child mInfo, Class<?> m, AvpSet parentSet)  throws RecoderException {
    final boolean multiple = m.isArray();
    final Class<?> elementType = multiple ? m.getComponentType() : m;
    final AvpDscr ad = storage.getClassInfo(elementType).getAnnotation(AvpDscr.class);
    final Avp av = parentSet.getAvp(ad.code());

    if (av != null) {
      for (AvpFlag i : ad.must()) {
        switch (i) {
          case M:
            if (!av.isMandatory()) throw new IllegalArgumentException("not set flag M");
            break;
          case V:
            if (!av.isVendorId()) throw new IllegalArgumentException("not set flag V");
            break;
          case P:
            if (!av.isEncrypted()) throw new IllegalArgumentException("not set flag P");
            break;
        }
      }
    } else if (mInfo != null && mInfo.min() > 0) {
      throw new IllegalArgumentException("Avp " + ad.code() + " is mandatory");
    }

    final boolean grouped = AvpType.Grouped.equals(ad.type());

    if (!multiple) {
      if (av == null) {
        // Optional AVP not present: nothing to build (a grouped AVP would otherwise NPE).
        return null;
      }
      return grouped ? newInstanceGroupedAvp(m, ad, av) : newInstanceSimpleAvp(m, ad, av);
    }

    AvpSet as = parentSet.getAvps(ad.code());
    // NOTE(review): getAvps presumably returns an empty set when nothing matches; null is tolerated just in case.
    int occurrences = (as == null) ? 0 : as.size();
    Object[] array = (Object[]) java.lang.reflect.Array.newInstance(elementType, occurrences);
    for (int ii = 0; ii < array.length; ii++) {
      Avp occurrence = as.getAvpByIndex(ii);
      array[ii] = grouped
          ? newInstanceGroupedAvp(elementType, ad, occurrence)
          : newInstanceSimpleAvp(elementType, ad, occurrence);
    }
    return array;
  }

  /**
   * Instantiates the domain class of a grouped AVP and populates it from the AVP's members.
   * <p>
   * Works like message decoding, one level down: the {@link Setter}-annotated constructor with
   * the most parameters (all of AVP classes) is used when available, otherwise the no-arg
   * constructor. Remaining members go through single-argument {@link Setter} methods, including
   * array-typed ones, and members not consumed that way are handed to the "undefined" setter.
   *
   * @param m   domain class to instantiate
   * @param ad  AVP descriptor of {@code m}; supplies the child descriptors
   * @param avp grouped AVP to read; may be {@code null}
   * @return the populated instance, or {@code null} when {@code avp} is {@code null}
   * @throws RecoderException if the instance cannot be created, a setter fails, or the grouped
   *                          data cannot be parsed
   */
  private Object newInstanceGroupedAvp(Class<?> m, AvpDscr ad, Avp avp) throws RecoderException {
    if (avp == null) {
      // Same contract as newInstanceSimpleAvp: no AVP, no object.
      return null;
    }
    ClassInfo c = storage.getClassInfo(m);

    // Pick the @Setter constructor with the most parameters, all of which must be AVP classes.
    int cacount = 0;
    Constructor<?> cm = null;
    for (ConstructorInfo ci : c.getConstructorsInfo()) {
      if (ci.getAnnotation(Setter.class) == null) {
        continue;
      }
      Class<?>[] params = ci.getConstructor().getParameterTypes();
      boolean correct = true;
      for (Class<?> j : params) {
        Class<?> jc = j.isArray() ? j.getComponentType() : j;
        if (storage.getClassInfo(jc).getAnnotation(AvpDscr.class) == null) {
          correct = false;
          break;
        }
      }
      if (correct && cacount < params.length) {
        cacount = params.length;
        cm = ci.getConstructor();
      }
    }

    Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>();
    Object rc;
    try {
      if (cm != null) {
        List<Object> initargs = new ArrayList<Object>();
        for (Class<?> ac : cm.getParameterTypes()) {
          Class<?> lac = ac.isArray() ? ac.getComponentType() : ac;
          cmargs.put(lac.getName(), ac);
          initargs.add(createChildByAvp(findChildDscr(ad.childs(), ac), ac, avp.getGrouped()));
        }
        rc = cm.newInstance(initargs.toArray());
      } else {
        rc = m.newInstance();
      }

      // Feed members not covered by the constructor through single-argument setters.
      for (MethodInfo mi : c.getMethodsInfo()) {
        if (mi.getAnnotation(Setter.class) == null) {
          continue;
        }
        Class<?>[] pt = mi.getMethod().getParameterTypes();
        if (pt.length != 1) {
          continue;
        }
        // Unwrap arrays before looking for AvpDscr, as the constructor path does;
        // otherwise array-typed setters are never invoked.
        Class<?> ptc = pt[0].isArray() ? pt[0].getComponentType() : pt[0];
        if (storage.getClassInfo(ptc).getAnnotation(AvpDscr.class) == null) {
          continue;
        }
        if (!cmargs.containsKey(ptc.getName())) {
          cmargs.put(ptc.getName(), ptc);
          // Static type Object keeps an array value as one reflective argument.
          Object arg = createChildByAvp(findChildDscr(ad.childs(), pt[0]), pt[0], avp.getGrouped());
          mi.getMethod().invoke(rc, arg);
        }
      }

      // Fill undefined child
      setUndefinedAvp(avp.getGrouped(), rc, c, cmargs);
    } catch (InstantiationException e) {
      throw new RecoderException(e);
    } catch (InvocationTargetException e) {
      throw new RecoderException(e);
    } catch (AvpDataException e) {
      throw new RecoderException(e);
    } catch (IllegalAccessException e) {
      throw new RecoderException(e);
    }
    return rc;
  }

  /**
   * Instantiates the domain class of a non-grouped AVP and gives it the AVP's value.
   * <p>
   * Every single-parameter {@link Setter}-annotated constructor is tried in turn with the
   * decoded value; the instance created last is kept. When no such constructor exists, the
   * no-arg constructor is used and the value is passed to every single-parameter
   * {@link Setter} method.
   *
   * @param m   domain class to instantiate
   * @param ad  AVP descriptor of {@code m}; its type selects how the value is decoded
   * @param avp AVP to read; may be {@code null}
   * @return the populated instance, or {@code null} when {@code avp} is {@code null}
   * @throws RecoderException if the instance cannot be created, a setter fails, or the AVP
   *                          data cannot be decoded
   */
  private Object newInstanceSimpleAvp(Class<?> m, AvpDscr ad, Avp avp)  {
    if (avp == null) {
      return null;
    }
    Object rc = null;
    ClassInfo c = storage.getClassInfo(m);
    try {
      for (ConstructorInfo ci : c.getConstructorsInfo()) {
        Constructor<?> ctor = ci.getConstructor();
        if (ctor.getParameterTypes().length == 1 && ci.getAnnotation(Setter.class) != null) {
          // Explicit Object[] so a byte[]/array value is still passed as ONE argument.
          rc = ctor.newInstance(new Object[] {getValue(ad.type(), avp)});
        }
      }
      if (rc == null) {
        rc = m.newInstance();
        for (MethodInfo mi : c.getMethodsInfo()) {
          if (mi.getAnnotation(Setter.class) == null) {
            continue;
          }
          // Only single-parameter setters can take the AVP value.
          if (mi.getMethod().getParameterTypes().length != 1) {
            continue;
          }
          // Pass the value itself; handing over a List wrapper would make the
          // reflective call fail with an argument type mismatch.
          mi.getMethod().invoke(rc, new Object[] {getValue(ad.type(), avp)});
        }
      }
    } catch (InstantiationException e) {
      throw new RecoderException(e);
    } catch (InvocationTargetException e) {
      throw new RecoderException(e);
    } catch (AvpDataException e) {
      throw new RecoderException(e);
    } catch (IllegalAccessException e) {
      throw new RecoderException(e);
    }
    return rc;
  }

  /**
   * Reads an AVP's payload with the accessor that corresponds to the declared AVP type.
   * <p>
   * Integer32/Enumerated use {@code getInteger32}, Unsigned32 uses {@code getUnsigned32},
   * Unsigned64/Integer64 use {@code getInteger64}, Float32/Float64 use the matching float
   * accessor, UTF8String uses {@code getUTF8String}, and the remaining listed types are read
   * with {@code getOctetString}.
   *
   * @param type declared AVP type
   * @param avp  AVP to read
   * @return the decoded value, or {@code null} for a type not handled here
   * @throws AvpDataException if the payload cannot be decoded as the requested type
   */
  private Object getValue(AvpType type, Avp avp) throws AvpDataException {
    Object decoded = null;
    switch (type) {
      case Integer32:
      case Enumerated:
        decoded = avp.getInteger32();
        break;
      case Unsigned32:
        decoded = avp.getUnsigned32();
        break;
      case Unsigned64:
      case Integer64:
        decoded = avp.getInteger64();
        break;
      case Float32:
        decoded = avp.getFloat32();
        break;
      case Float64:
        decoded = avp.getFloat64();
        break;
      case OctetString:
      case Address:
      case Time:
      case DiameterIdentity:
      case DiameterURI:
      case IPFilterRule:
      case QoSFilterRule:
        decoded = avp.getOctetString();
        break;
      case UTF8String:
        decoded = avp.getUTF8String();
        break;
      default:
        break;
    }
    return decoded;
  }

  // =======================================================================================
}